import json
import os

from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

# Part-of-speech tagging backend.
# An earlier NLTK-based tagger is kept here, commented out, for reference:
# import nltk 
# nltk.download('averaged_perceptron_tagger') 
import spacy 
# Shared spaCy pipeline, loaded once at import time and used by check_noun().
nlp = spacy.load("en_core_web_sm") 

def check_noun(text):
    """Return True when the first token of ``text`` is tagged as a noun.

    The text is run through the module-level spaCy pipeline ``nlp`` and only
    the fine-grained tag of the first token is inspected, so callers are
    expected to pass a single word.

    Args:
        text: Candidate answer string.

    Returns:
        True if the first token's tag is one of NN, NNS, NNP or NNPS;
        False otherwise, including for empty or whitespace-only input.
    """
    # Nothing to tag: without this guard an empty string produces an empty
    # document and ``doc[0]`` raises IndexError.
    if not text or not text.strip():
        return False

    doc = nlp(text)
    if len(doc) == 0:
        return False

    # Penn Treebank noun tags: singular, plural, proper singular, proper plural.
    return doc[0].tag_ in ("NN", "NNS", "NNP", "NNPS")


class VQADataset(Dataset):
    """Map-style dataset over a JSON-lines VQA annotation file.

    Every non-blank line of the annotation file is parsed as one JSON object.
    ``__getitem__`` reads the keys ``question``, ``question_id`` and
    ``labels`` from each annotation, plus ``file_id`` when images are loaded.
    """

    def __init__(
        self, annotations_path, image_dir_path=None
    ):
        """Load all annotations into memory.

        Args:
            annotations_path: Path to a ``.jsonl`` file, one annotation per line.
            image_dir_path: Root directory of the images. When ``None`` (the
                default), items are returned without an image.
        """
        # Use a context manager so the file handle is closed deterministically,
        # and skip blank lines, which would otherwise make json.loads fail.
        with open(annotations_path, "r") as f:
            self.annotations = [
                json.loads(line.strip()) for line in f if line.strip()
            ]

        self.image_dir_path = image_dir_path

    def __len__(self):
        """Return the number of loaded annotations."""
        return len(self.annotations)

    def get_img_path(self, question):
        """Build the image path for an annotation.

        Args:
            question: Annotation dict containing a ``file_id`` entry.

        Returns:
            ``<image_dir_path>/<split>/<file_id>.jpg`` where ``<split>`` is the
            second ``_``-separated field of ``file_id``.
        """
        img_file_name = question["file_id"]
        # NOTE(review): presumably file_id is COCO-style, e.g.
        # "COCO_val2014_000000000001" -> split "val2014"; confirm with the data.
        img_coco_split = img_file_name.split("_")[1]
        return os.path.join(
            self.image_dir_path, img_coco_split, img_file_name + ".jpg"
        )

    def __getitem__(self, idx):
        """Return the annotation at ``idx`` as a new dict.

        The result always holds ``question``, ``question_id`` and ``labels``.
        When ``image_dir_path`` is set it also holds the loaded PIL ``image``
        and its ``img_path``. Any remaining annotation keys are copied over
        without overriding the ones above.
        """
        ann = self.annotations[idx]

        results = {
            "question": ann["question"],
            "question_id": ann["question_id"],
            "labels": ann["labels"],
        }

        if self.image_dir_path is not None:
            # Pass the annotation itself; the previous code referenced an
            # undefined name here and raised NameError.
            img_path = self.get_img_path(ann)
            image = Image.open(img_path)
            # Force the pixel data to be read now rather than lazily.
            image.load()
            results.update({
                "image": image,
                "img_path": img_path,
            })

        # Forward every extra annotation field that was not set explicitly.
        for k in ann:
            if k not in results:
                results[k] = ann[k]
        return results

def nltk_detect_noun(split="trainval", max_questions=100000):
    """Collect questions whose best single-word answer is a noun.

    Reads ``<DATA_FOLDER>/vqa/vqa_k_{split}.jsonl`` through ``VQADataset``,
    scans at most ``max_questions`` annotations and, for each one, walks the
    answers in descending score order. The first answer that is exactly one
    word, has a score above 0.3 and passes ``check_noun`` is stored under
    ``noun_ans``. Questions starting with "what color" are skipped. The
    selected items are written to ``./data/vqav2/vqa_k_{split}_noun.jsonl``
    and summary statistics are printed.

    Args:
        split: Dataset split name used in the input and output file names.
        max_questions: Upper bound on the number of annotations scanned.
    """
    # NOTE(review): "<DATA_FOLDER>" looks like a placeholder that has to be
    # replaced with the real data root before running; confirm.
    vqa_val = VQADataset(annotations_path=f"<DATA_FOLDER>/vqa/vqa_k_{split}.jsonl")
    total_questions = len(vqa_val)
    max_iter = min(max_questions, total_questions)
    res = []
    for index in tqdm(range(max_iter)):
        d = vqa_val[index]
        if d["question"].lower().startswith("what color"):
            continue
        # Highest-scoring answers first, so the best qualifying noun wins.
        sorted_labels = sorted(d["labels"].items(), key=lambda x: x[1], reverse=True)
        for ans, score in sorted_labels:
            # Keep single-word answers only; this also rejects empty strings,
            # which the tagger cannot handle.
            if len(ans.split()) != 1:
                continue
            if score <= 0.3:
                continue

            if check_noun(ans):
                d["noun_ans"] = ans
                res.append(d)
                break

    output_path = f"./data/vqav2/vqa_k_{split}_noun.jsonl"
    # Create the output directory up front so the run does not fail at the
    # very end after all the tagging work is done.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w") as f:
        for item in res:
            f.write(json.dumps(item) + "\n")

    num_noun_questions = len(res)
    # The ratio is taken over the questions actually scanned (skipped
    # "what color" questions included), not over the whole dataset, and an
    # empty dataset must not raise ZeroDivisionError.
    percentage = num_noun_questions / max_iter * 100 if max_iter else 0.0
    print(f"num_noun_questions: {num_noun_questions}")
    print(f"total_questions: {total_questions}")
    print(f"processed_questions: {max_iter}")
    print(f"percentage: {percentage:.2f}%")


import openai
import json
from tqdm import tqdm
import time
import argparse
import os 

# Load the OpenAI credentials file: line 1 is the API key, line 2 the API base URL.
with open('credentials/openai_key.txt') as f:
    # Read the file exactly once. Calling f.readlines() a second time returns
    # an empty list because the file position is already at EOF, so indexing
    # its result with [1] always raised IndexError.
    _credential_lines = f.read().splitlines()

if len(_credential_lines) < 2:
    raise ValueError(
        "credentials/openai_key.txt must contain the API key on the first "
        "line and the API base URL on the second line"
    )

key = _credential_lines[0].strip()
api_base = _credential_lines[1].strip()

openai.api_key = key
openai.api_base = api_base
# openai.api_version = '2022-12-01' # this may change in the future
openai.api_version = "2023-07-01-preview"  # this may change in the future
# deployment_id='gpt-4-32k-0314' #This will correspond to the custom name you chose for your deployment when you deployed a model.


# Baseline system prompt: task description followed by few-shot examples only.
# It is not referenced by the code visible in this file (get_noun_obj_in_q
# sends prompt_v1 instead).
prompt = """
You are given a question and an answer based on an image. Return the most relevant object in the image that the question is asking about. 

Here are some examples:
- {"question": "What is the color of the car?",  "answer": "red"} 
  Relevant object: red car
- {"question": "What objects are reflected?",  "answer": "trees"} 
  Relevant object: trees
- {"question": "What brand of bike can you see?",  "answer": "yamaha"} 
  Relevant object: yamaha bike
- {"question": "What is stopping the animals from running away?",  "answer": "wall"} 
  Relevant object: wall
"""


# System prompt sent by get_noun_obj_in_q. It lists six policies, each with
# examples whose expected reply starts with "Relevant object:"; the caller
# lower-cases the reply and only accepts it when it starts with that prefix.
# "na" is the reply the prompt asks for when no single object should be masked.
# NOTE(review): the text contains a few typos (e.g. "decription", "Dont over
# do it"); they are left untouched because the string is sent verbatim to the
# model and editing it may change the model's outputs.
prompt_v1 = """
You are given a question and an answer based on an image. Return the most relevant object in the image that the question is asking about. 

There are some policies to follow:
1. The most relevant object should be the one that when removed from the image, the question would become unanswerable. Here are some examples:
  - {"question": "What is the color of the car?",  "answer": "red"} 
    Relevant object: red car
  - {"question": "What objects are reflected?",  "answer": "trees"} 
    Relevant object: trees
  - {"question": "What brand of bike can you see?",  "answer": "yamaha"} 
    Relevant object: yamaha bike
  - {"question": "What is stopping the animals from running away?",  "answer": "wall"} 
    Relevant object: wall

2. Remember that are limitations in removing object from the image. If the question is regarding the overall presentation of the image, it is impossible to masking out the whole image, so the answer should be na. For example,
  - {"question": "Is this picture taken during the day or night?",  "answer": "day"}
    Relevant object: na
  - {"question": "Is this a house kitchen or a restaurant kitchen?",  "answer": "restaurant"}
    Relevant object: na
  Dont over do it for policy 2, for example,
  - {"question": "Is the rider a child or an adult?", "answer": "adult"}
    Relevant object: adult rider


3. Imagine that even after masking the most relevant object, the question can still be answered, then the answer should be na. For example,
  - {"question": "What is the woman standing on?",  "answer": "floor"}
    Relevant object: na
    Reasoning: we can still reason that she is standing on the floor, given the rest of the context of the image
  - {"question": "What is the person standing on?",  "answer": "ski"}
    Relevant object: na
    Reasoning: we can still reason that he or she is standing on snow, given the rest of the context of the image


4. In the case that there are rich descriptions about the object mentioned in the question, the answer should be the most relevant object that is mentioned in the question, and please try keep the decription intact. For example,
  - {"question": "What does the sign on the door on the bottom right say?",  "answer": "caution"} 
    Relevant object: the caution sign on the door on the bottom right
  - {"question": "What stuffed animal is the child in the red jacket holding?",  "answer": "teddy bear"} 
    Relevant object: teddy bear that the child in the red jacket is holding

5. When the question can be answered, regardless of what is in the image
  - {"question": "Glasses assist in helping what organ?",  "answer": "eyes"} 
    Relevant object: na

6. For questions that are general, please evaluate how often there might be multiple objects belonging to the same category appearing in a scene, and return the most plausible answer. For example,
  - {"question": "What food is presented?",  "answer": "sandwich"} 
    Relevant object: "food"
  - {"question": "What is being eaten?",  "answer": "sandwich"} 
    Relevant object: "food"
"""


def get_noun_obj_in_q(annotation_file="data/vqav2/vqa_k_test_noun.jsonl", deployment_id="gpt4", max_num_retries=5, output_file="data/vqav2/vqa_k_test_noun_gpt4.jsonl", debug=False):
    """Ask the chat model for the most relevant object of each question.

    For every annotation in ``annotation_file`` the question and its
    ``noun_ans`` are sent together with ``prompt_v1``. A reply that starts
    with "relevant object:" (after lower-casing) is stored under
    ``gpt_noun_ans`` and the annotation is appended to ``output_file`` as one
    JSON line. The run is resumable: question ids already present in
    ``output_file`` are not queried again.

    Args:
        annotation_file: JSON-lines input; each row needs ``question``,
            ``noun_ans`` and ``question_id``.
        deployment_id: Value passed as ``engine`` to the chat completion call.
        max_num_retries: Maximum number of attempts per question. Failed API
            calls and malformed replies both count as an attempt.
        output_file: JSON-lines output path.
        debug: When true, only the first 100 rows are processed and results
            go to a fresh ``*_debug.jsonl`` file.
    """
    with open(annotation_file, "r") as f_in:
        data = [json.loads(line.strip()) for line in f_in if line.strip()]
    if debug:
        data = data[:100]
        output_file = output_file.replace(".jsonl", "_debug.jsonl")
        if os.path.exists(output_file):
            print(f"{output_file} already exists, removing cuz in debug mode")
            os.remove(output_file)

    # Start with no previous results so that a missing output file (always the
    # case in debug mode) does not leave these names undefined.
    output_data = []
    if os.path.exists(output_file):
        print(f"{output_file} already exists, reloading")
        with open(output_file, "r") as f_prev:
            output_data = [json.loads(line.strip()) for line in f_prev if line.strip()]
    output_qids = {d["question_id"] for d in output_data}

    with open(output_file, "w") as f:
        # Re-emit the earlier results first, since "w" truncated the file.
        for d in tqdm(output_data):
            f.write(json.dumps(d) + "\n")

        for d in tqdm(data):
            question = d["question"]
            answer = d["noun_ans"]
            qid = d["question_id"]
            if qid in output_qids:
                continue
            messages = [
                {"role": "system", "content": prompt_v1},
                {"role": "user", "content": json.dumps({"question": question, "answer": answer})},
            ]
            # Every attempt counts toward the limit, failed API calls included,
            # so a persistently failing request can no longer loop forever.
            for _ in range(max_num_retries):
                try:
                    response = openai.ChatCompletion.create(
                        engine=deployment_id,
                        messages=messages,
                        temperature=0.7,
                        max_tokens=100,
                        top_p=0.95,
                        frequency_penalty=0,
                        presence_penalty=0,
                        stop=None,
                    )
                    content = response['choices'][0]['message']['content'].strip().lower()
                except Exception as e:
                    # Deliberately broad: this is a best-effort batch job and
                    # any API or response-shape error should trigger a retry.
                    if "content management policy" in f"{e}":
                        print("Skipping due to content management policy")
                        break
                    print(f"Failed to call GPT-4 ({e}), sleep 1s")
                    time.sleep(1)
                    continue
                if not content.startswith("relevant object:"):
                    # Reply is not in the expected format; ask again.
                    continue
                d["gpt_noun_ans"] = content.split("relevant object:")[-1].strip()
                f.write(json.dumps(d) + "\n")
                # Flush so finished rows survive an interruption and can be
                # picked up by the resume logic on the next run.
                f.flush()
                break
            else:
                # Reached only when no attempt ended in a break (success or
                # policy skip), so a success on the last attempt is not
                # misreported as a failure.
                print(f"Failed to call GPT-4 for {qid}")

# Script entry point: hand control to a python-fire command-line interface.
if __name__ == '__main__':
  # Imported here so that merely importing this module does not require `fire`.
  import fire
  # Called without a component; presumably this exposes the module's functions
  # (e.g. nltk_detect_noun, get_noun_obj_in_q) as CLI commands, so confirm
  # against the python-fire documentation.
  fire.Fire()
